import sys
import adafruit_bno055
import board
import busio
import numpy as np
import os
import pickle

from queue import Queue
from threading import Thread
import time

from rich.live import Live
from rich.table import Table


CALIBRATE_FILE = 'imu_calib_data_250922.pkl'


# TODO filter spikes
class Imu:
    def __init__(
        self, sampling_freq, user_pitch_bias=0, calibrate=False, upside_down=True
    ):
        self.sampling_freq = sampling_freq
        self.calibrate = calibrate

        i2c = busio.I2C(board.SCL, board.SDA)
        self.imu = adafruit_bno055.BNO055_I2C(i2c)
        
        # self.imu.mode = adafruit_bno055.IMUPLUS_MODE
        # self.imu.mode = adafruit_bno055.ACCGYRO_MODE
        # self.imu.mode = adafruit_bno055.GYRONLY_MODE
        self.imu.mode = adafruit_bno055.NDOF_MODE
        # self.imu.mode = adafruit_bno055.NDOF_FMC_OFF_MODE

        if upside_down:
            self.imu.axis_remap = (
                adafruit_bno055.AXIS_REMAP_Y,
                adafruit_bno055.AXIS_REMAP_X,
                adafruit_bno055.AXIS_REMAP_Z,
                adafruit_bno055.AXIS_REMAP_NEGATIVE,
                adafruit_bno055.AXIS_REMAP_NEGATIVE,
                adafruit_bno055.AXIS_REMAP_NEGATIVE,
            )

        else:
            self.imu.axis_remap = (
                adafruit_bno055.AXIS_REMAP_Y,
                adafruit_bno055.AXIS_REMAP_X,
                adafruit_bno055.AXIS_REMAP_Z,
                adafruit_bno055.AXIS_REMAP_NEGATIVE,
                adafruit_bno055.AXIS_REMAP_POSITIVE,
                adafruit_bno055.AXIS_REMAP_POSITIVE,
            )

        if self.calibrate: # Perform initial calibration work!!!
            self.imu.mode = adafruit_bno055.NDOF_MODE
            calibrated = self.imu.calibrated
            while not calibrated:
                print("Calibration status: ", self.imu.calibration_status)
                print("Calibrated : ", self.imu.calibrated)
                calibrated = self.imu.calibrated
                time.sleep(0.1)
            print("CALIBRATION DONE")
            offsets_accelerometer = self.imu.offsets_accelerometer
            offsets_gyroscope = self.imu.offsets_gyroscope
            offsets_magnetometer = self.imu.offsets_magnetometer

            imu_calib_data = {
                "offsets_accelerometer": offsets_accelerometer,
                "offsets_gyroscope": offsets_gyroscope,
                "offsets_magnetometer": offsets_magnetometer,
            }
            for k, v in imu_calib_data.items():
                print(k, v)

            pickle.dump(imu_calib_data, open("imu_calib_data.pkl", "wb"))

            print("Saved", "imu_calib_data.pkl")
        else: # Use IMU with CALIBRATE_FILE
            if os.path.exists(CALIBRATE_FILE):
                imu_calib_data = pickle.load(open(CALIBRATE_FILE, "rb"))
                self.imu.mode = adafruit_bno055.CONFIG_MODE
                time.sleep(0.1)
                self.imu.offsets_accelerometer = imu_calib_data["offsets_accelerometer"]
                self.imu.offsets_gyroscope = imu_calib_data["offsets_gyroscope"]
                self.imu.offsets_magnetometer = imu_calib_data["offsets_magnetometer"]
                self.imu.mode = adafruit_bno055.NDOF_MODE
                time.sleep(0.1)
                print('IMU data loaded!!!')
            else:
                print(f"Calibrate file {CALIBRATE_FILE} not found")
                print("Imu is running uncalibrated")
                exit()

            self.x_offset = 0

            # self.tare_x()

            self.last_imu_data = [0, 0, 0, 0]
            self.last_imu_data = {
                "gyro": [0, 0, 0],
                "accelero": [0, 0, 0],
            }
            self.imu_queue = Queue(maxsize=1)
            Thread(target=self.imu_worker, daemon=True).start()

    def tare_x(self):
        print("Taring x ...")
        x_values = []
        num_values = 100
        ok = False
        while not ok:
            x_values.append(np.array(self.imu.acceleration)[0])

            x_values = x_values[-num_values:]

            if len(x_values) == num_values:
                mean = np.mean(x_values)
                std = np.std(x_values)
                if std < 0.05:
                    ok = True
                    self.x_offset = mean
                    print("Tare x done")
                else:
                    print(std)

            time.sleep(0.01)

    def imu_worker(self):
        while True:
            s = time.time()
            try:
                gyro = np.array(self.imu.gyro).copy()
                accelero = np.array(self.imu.acceleration).copy()
            except Exception as e:
                print("[IMU]:", e)
                continue

            if gyro is None or accelero is None:
                continue

            if gyro.any() is None or accelero.any() is None:
                continue

            accelero[0] -= self.x_offset

            data = {
                "gyro": gyro,
                "accelero": accelero,
            }

            self.imu_queue.put(data)
            took = time.time() - s
            time.sleep(max(0, 1 / self.sampling_freq - took))

    def get_data(self):
        try:
            self.last_imu_data = self.imu_queue.get(False)  # non blocking
        except Exception:
            pass

        return self.last_imu_data



def float2str(num):
    res = f'{num:.3f}'
    
    # 转换为字符串并处理正负号
    if num >= 0.:
        res = "+" + res
    
    return res




if __name__ == "__main__":
    from mini_bdx_runtime.keyboard_utils import is_any_key_pressed

    imu = Imu(50, upside_down=True)
    
    with Live(refresh_per_second=30) as live:
        while True:
            data = imu.get_data()

            table = Table(title = 'IMU Sensor Data', show_header=True)

            table.add_column('Type', width = 15)
            table.add_column('X', width = 10)
            table.add_column('Y', width = 10)
            table.add_column('Z', width = 10)

            gyro = np.around(data["gyro"], 3)
            table.add_row('gyro', float2str(gyro[0]), float2str(gyro[1]), float2str(gyro[2]))

            accelero = np.around(data["accelero"], 3)
            table.add_row('accelero', float2str(accelero[0]), float2str(accelero[1]), float2str(accelero[2]))

            live.update(table)
            
            if is_any_key_pressed(['\x1b', 'q']):
                break

            time.sleep(1 / 25)
